package com.change.service.common;

import com.change.core.web.request.AccountContext;
import com.change.exception.BizException;
import com.change.exception.ResultStatus;
import com.change.mapper.common.CommonMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Runs ad-hoc SQL statements supplied by callers.
 *
 * <p>Reads go through {@link #queryList(String)} and are capped at {@code PAGE_SIZE} rows.
 * Writes go through {@link #updateSql(String)}, which only accepts an UPDATE whose WHERE clause
 * matches exactly one row, or a simple {@code INSERT INTO t (cols) VALUES (vals)}, and stamps
 * the account id from {@link AccountContext#getAccountId()} onto the statement.
 *
 * <p>Every statement is first normalized by {@link #normalize(String)}: text outside string
 * literals is upper-cased, literals are kept verbatim, and multiple statements or SQL comments
 * are rejected because they could hide or cut off the conditions this class appends.
 *
 * <p>NOTE(review): this class still executes caller-controlled SQL text. The checks here are
 * heuristics, not a SQL parser.
 */
@Service
@Slf4j
public class CommonService {

    /** Maximum number of rows {@link #queryList(String)} returns to the caller. */
    private static final int PAGE_SIZE = 20;

    /** {@code WHERE} as a whole word, so identifiers such as {@code ELSEWHERE} do not match. */
    private static final Pattern WHERE_KEYWORD = Pattern.compile("\\bWHERE\\b");

    /** {@code VALUES} (or MySQL's {@code VALUE}) as a whole word. */
    private static final Pattern VALUES_KEYWORD = Pattern.compile("\\bVALUES?\\b");

    private final CommonMapper commonMapper;

    public CommonService(CommonMapper commonMapper) {
        this.commonMapper = commonMapper;
    }

    /**
     * General-purpose query: executes a caller-supplied SELECT and returns at most
     * {@code PAGE_SIZE} (20) rows.
     *
     * <p>Truncation happens after the mapper returns, so a broad query is still fully
     * materialized in memory before being cut down.
     *
     * @param sql the SELECT statement to run
     * @return at most {@code PAGE_SIZE} rows as returned by the mapper
     * @throws BizException {@code INVALIDATE_SQL} if the statement is empty, is not a SELECT,
     *     contains several statements or SQL comments, or has an unterminated string literal
     */
    public List<Map<Integer, String>> queryList(String sql) {
        sql = normalize(sql);
        log.info("执行的查询语句：{}", sql);
        // Only reads are allowed here; writes must go through updateSql's ownership checks.
        if (!sql.startsWith("SELECT")) {
            throw new BizException(ResultStatus.Common.INVALIDATE_SQL);
        }
        List<Map<Integer, String>> maps = commonMapper.queryList(sql);
        // Cap the page size to limit the damage of queries returning excessive rows.
        if (maps.size() > PAGE_SIZE) {
            return maps.stream().limit(PAGE_SIZE).collect(Collectors.toList());
        }
        return maps;
    }

    /**
     * General-purpose write: executes a caller-supplied INSERT or UPDATE after validating it and
     * stamping the current account onto it.
     *
     * <ul>
     *   <li>UPDATE: the WHERE clause must match exactly one row; the condition is then wrapped
     *       in parentheses and restricted with {@code AND OWNER_ID = <current account id>}.
     *   <li>INSERT: must look like {@code INSERT INTO t (cols) VALUES (vals)}; CREATE_TIME
     *       (NOW()) plus CREATOR_ID, UPDATE_ID, OWNER_ID and TENANT_ID (all set to the current
     *       account id) are appended to the column and value lists.
     * </ul>
     *
     * @param sql the INSERT or UPDATE statement
     * @return the row count returned by the mapper
     * @throws BizException {@code UPDATE_VALIDATE_SQL} / {@code INSERT_VALIDATE_SQL} when the
     *     respective validation fails; {@code INVALIDATE_SQL} for any other statement or for
     *     input rejected by normalization
     */
    public Integer updateSql(String sql) {
        sql = normalize(sql);
        log.info("执行的更新语句：{}", sql);
        if (sql.startsWith("UPDATE")) {
            if (!validateUpdate(sql)) {
                throw new BizException(ResultStatus.Common.UPDATE_VALIDATE_SQL);
            }
            sql = updateAppendAccountId(sql);
        } else if (sql.startsWith("INSERT")) {
            if (!validateInsert(sql)) {
                throw new BizException(ResultStatus.Common.INSERT_VALIDATE_SQL);
            }
            sql = appendAccountId(sql);
        } else {
            throw new BizException(ResultStatus.Common.INVALIDATE_SQL);
        }
        return commonMapper.updateSql(sql);
    }

    /**
     * Restricts an already-validated UPDATE to rows owned by the current account.
     *
     * <p>The original WHERE condition is wrapped in parentheses before the owner condition is
     * appended. Without them, {@code WHERE ID = 1 OR ID = 2} would parse as
     * {@code ID = 1 OR (ID = 2 AND OWNER_ID = x)}, letting the first branch touch rows owned by
     * other accounts.
     *
     * @param sql normalized UPDATE statement containing a WHERE clause
     * @return the statement with the ownership condition applied
     */
    private String updateAppendAccountId(String sql) {
        int accountId = AccountContext.getAccountId();
        int where = indexOfKeyword(sql, WHERE_KEYWORD);
        if (where < 0) {
            // validateUpdate already requires a WHERE clause; never emit an unrestricted UPDATE.
            throw new BizException(ResultStatus.Common.UPDATE_VALIDATE_SQL);
        }
        int conditionStart = where + "WHERE".length();
        String finalSql = sql.substring(0, conditionStart)
                + " (" + sql.substring(conditionStart).trim() + ")"
                + " AND OWNER_ID = " + accountId;
        log.info("修改格式语句后：{}", finalSql);
        return finalSql;
    }

    /**
     * Adds audit columns to an already-validated INSERT.
     *
     * <p>CREATE_TIME, CREATOR_ID, UPDATE_ID, OWNER_ID and TENANT_ID are added after the last
     * listed column; NOW() followed by the current account id four times is added after the
     * last value.
     *
     * @param sql normalized statement shaped like {@code INSERT INTO t (cols) VALUES (vals)}
     * @return the statement with the audit columns and values added
     */
    private String appendAccountId(String sql) {
        int accountId = AccountContext.getAccountId();
        int columnsEnd = indexOfOutsideLiterals(sql, ')');
        if (columnsEnd < 0 || !sql.endsWith(")")) {
            // validateInsert guarantees both; guard so a bad index cannot corrupt the statement.
            throw new BizException(ResultStatus.Common.INSERT_VALIDATE_SQL);
        }
        StringBuilder sb = new StringBuilder(sql);
        // Insert the values first: editing the tail does not shift columnsEnd.
        sb.insert(sql.length() - 1,
                ",NOW()," + accountId + "," + accountId + "," + accountId + "," + accountId);
        sb.insert(columnsEnd, ",CREATE_TIME,CREATOR_ID,UPDATE_ID,OWNER_ID,TENANT_ID");
        String finalSql = sb.toString();
        log.info("新增格式语句后：{}", finalSql);
        return finalSql;
    }

    /**
     * Checks that an INSERT has the simple shape {@link #appendAccountId(String)} can rewrite.
     *
     * <p>Requires:
     * <ul>
     *   <li>fewer than three opening parentheses outside string literals; this rules out
     *       multi-row VALUES lists and nested sub-queries, but also rejects function calls
     *       such as NOW() in the values;
     *   <li>a column list closed before the VALUES keyword;
     *   <li>a statement ending with the closing parenthesis of the value list.
     * </ul>
     *
     * @param sql normalized INSERT statement
     * @return {@code true} if the statement passes all checks
     */
    private boolean validateInsert(String sql) {
        boolean[] inLiteral = literalMap(sql);
        int openParens = 0;
        for (int i = 0; i < sql.length(); i++) {
            if (!inLiteral[i] && sql.charAt(i) == '(') {
                openParens++;
            }
        }
        if (openParens >= 3) {
            return false;
        }
        int columnsEnd = indexOfOutsideLiterals(sql, ')');
        int values = indexOfKeyword(sql, VALUES_KEYWORD);
        return columnsEnd >= 0 && values > columnsEnd && sql.endsWith(")");
    }

    /**
     * Checks that an UPDATE touches exactly one row; only one row may be updated at a time.
     *
     * <p>Runs {@code SELECT COUNT(1) FROM <table> <where-clause>} with the statement's own WHERE
     * clause. The count ignores ownership, and the statement is accepted only when the count
     * is exactly 1.
     *
     * @param sql normalized UPDATE statement
     * @return {@code false} if there is no WHERE clause or the count is not exactly 1
     */
    private boolean validateUpdate(String sql) {
        int where = indexOfKeyword(sql, WHERE_KEYWORD);
        if (where < 0) {
            return false;
        }
        // Second token of "UPDATE <table> SET ..."; split on any whitespace run, not one space.
        String table = sql.split("\\s+")[1];
        String querySql = "SELECT COUNT(1) FROM " + table + " " + sql.substring(where);
        log.info("查询更新语句数量：{}", querySql);
        Integer count = this.commonMapper.queryCount(querySql);
        return count != null && count == 1;
    }

    /**
     * Canonicalizes a caller-supplied statement before validation and execution.
     *
     * <p>The steps are:
     * <ul>
     *   <li>trim surrounding whitespace and drop trailing semicolons;
     *   <li>upper-case everything outside single-quoted string literals, so keyword checks work
     *       while literal values such as {@code 'Alice'} are passed through unchanged;
     *   <li>reject {@code ;}, {@code --}, {@code #} and {@code /*} outside literals, since they
     *       could smuggle in a second statement or comment out appended conditions.
     * </ul>
     *
     * @param sql raw caller input
     * @return the normalized statement
     * @throws BizException {@code INVALIDATE_SQL} on null/blank input, forbidden tokens, or an
     *     unterminated literal
     */
    private static String normalize(String sql) {
        if (sql == null) {
            throw new BizException(ResultStatus.Common.INVALIDATE_SQL);
        }
        String trimmed = sql.trim();
        while (trimmed.endsWith(";")) {
            trimmed = trimmed.substring(0, trimmed.length() - 1).trim();
        }
        if (trimmed.isEmpty()) {
            throw new BizException(ResultStatus.Common.INVALIDATE_SQL);
        }
        boolean[] inLiteral = literalMap(trimmed);
        StringBuilder out = new StringBuilder(trimmed.length());
        for (int i = 0; i < trimmed.length(); i++) {
            char c = trimmed.charAt(i);
            if (inLiteral[i]) {
                out.append(c);
                continue;
            }
            char next = i + 1 < trimmed.length() ? trimmed.charAt(i + 1) : '\0';
            boolean forbidden = c == ';' || c == '#'
                    || (c == '-' && next == '-')
                    || (c == '/' && next == '*');
            if (forbidden) {
                throw new BizException(ResultStatus.Common.INVALIDATE_SQL);
            }
            // Character.toUpperCase(char) is locale-independent (unlike String.toUpperCase()).
            out.append(Character.toUpperCase(c));
        }
        return out.toString();
    }

    /**
     * Marks which characters of {@code sql} belong to a single-quoted string literal, quotes
     * included.
     *
     * <p>A doubled quote ({@code ''}) and a backslash escape (backslash followed by any
     * character) both stay inside the literal. NOTE(review): backslash escaping assumes
     * MySQL-style literals; confirm against the target database. A single quote inside a
     * double-quoted string is treated as opening a literal.
     *
     * @param sql statement to scan
     * @return one flag per character, {@code true} when that character is inside a literal
     * @throws BizException {@code INVALIDATE_SQL} if a literal is left unterminated
     */
    private static boolean[] literalMap(String sql) {
        boolean[] inLiteral = new boolean[sql.length()];
        boolean open = false;
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (open) {
                inLiteral[i] = true;
                if (c == '\\' && i + 1 < sql.length()) {
                    // Escaped character (e.g. an escaped quote) stays inside the literal.
                    inLiteral[++i] = true;
                } else if (c == '\'') {
                    // Closing quote; a doubled quote simply reopens on the next character.
                    open = false;
                }
            } else if (c == '\'') {
                open = true;
                inLiteral[i] = true;
            }
        }
        if (open) {
            throw new BizException(ResultStatus.Common.INVALIDATE_SQL);
        }
        return inLiteral;
    }

    /**
     * Finds the first match of {@code keyword} that lies outside string literals.
     *
     * @return the match start index, or -1 if there is none
     */
    private static int indexOfKeyword(String sql, Pattern keyword) {
        boolean[] inLiteral = literalMap(sql);
        Matcher m = keyword.matcher(sql);
        while (m.find()) {
            if (!inLiteral[m.start()]) {
                return m.start();
            }
        }
        return -1;
    }

    /**
     * Finds the first occurrence of {@code target} that lies outside string literals.
     *
     * @return the index, or -1 if there is none
     */
    private static int indexOfOutsideLiterals(String sql, char target) {
        boolean[] inLiteral = literalMap(sql);
        for (int i = 0; i < sql.length(); i++) {
            if (!inLiteral[i] && sql.charAt(i) == target) {
                return i;
            }
        }
        return -1;
    }

}
